// Copyright Epic Games, Inc. All Rights Reserved.

#include "dlltransport.h"

#include <zencore/except.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/scopeguard.h>

#include <exception>
#include <system_error>
#include <thread>
#include <vector>

ZEN_THIRD_PARTY_INCLUDES_START
#include <fmt/format.h>
ZEN_THIRD_PARTY_INCLUDES_END

#if ZEN_WITH_PLUGINS

namespace zen {

//////////////////////////////////////////////////////////////////////////

// Bookkeeping record for one transport plugin that was loaded from a DLL.
struct LoadedDll
{
	std::string			  Name;			   // name passed to LoadDll(), without the ".dll" suffix
	std::filesystem::path LoadedFromPath;  // file name handed to LoadLibraryA (Name + ".dll")
	Ref<TransportPlugin>  Plugin;		   // instance produced by the DLL's CreateTransportPlugin export
};

// Transport plugin that hosts other transport plugins loaded from DLLs.
// Lifecycle calls (Initialize/Shutdown) are forwarded to every hosted plugin,
// and per-DLL options are routed through ConfigureDll().
class DllTransportPluginImpl : public DllTransportPlugin, private RefCounted
{
public:
	DllTransportPluginImpl();
	~DllTransportPluginImpl();

	// Reference counting, forwarded to the privately inherited RefCounted base.
	virtual uint32_t AddRef() const override;
	virtual uint32_t Release() const override;

	// Plugin configuration and lifecycle.
	virtual void		Configure(const char* OptionTag, const char* OptionValue) override;
	virtual void		Initialize(TransportServer* ServerInterface) override;
	virtual void		Shutdown() override;
	virtual const char* GetDebugName() override;
	virtual bool		IsAvailable() override;

	// DLL hosting.
	virtual void LoadDll(std::string_view Name) override;
	virtual void ConfigureDll(std::string_view Name, const char* OptionTag, const char* OptionValue) override;

private:
	TransportServer*	   m_ServerInterface = nullptr;	 // server handed to Initialize()
	RwLock				   m_Lock;						 // guards m_Transports
	std::vector<LoadedDll> m_Transports;				 // one entry per successfully loaded DLL
};

// Members are fully initialized by their in-class initializers / default constructors.
DllTransportPluginImpl::DllTransportPluginImpl() = default;

// Member destructors release the hosted plugin references; no module handles are freed here.
DllTransportPluginImpl::~DllTransportPluginImpl() = default;

uint32_t
DllTransportPluginImpl::AddRef() const
{
	// Delegate to the privately inherited RefCounted base so the plugin
	// interface and RefCounted share a single reference count.
	const auto NewCount = RefCounted::AddRef();
	return NewCount;
}

uint32_t
DllTransportPluginImpl::Release() const
{
	// Delegate to the privately inherited RefCounted base. The returned count
	// is held in a local so nothing touches members after the call.
	const auto RemainingCount = RefCounted::Release();
	return RemainingCount;
}

void
DllTransportPluginImpl::Configure(const char* OptionTag, const char* OptionValue)
{
	// The DLL host accepts no options of its own; options for an individual
	// hosted plugin are delivered via ConfigureDll().
	(void)OptionTag;
	(void)OptionValue;
}

void
DllTransportPluginImpl::Initialize(TransportServer* ServerIface)
{
	// Remember the server interface and initialize every hosted plugin with it.
	//
	// m_ServerInterface is written under m_Lock, the same lock that guards the
	// transport list, so all mutable plugin state is updated under one lock.
	RwLock::ExclusiveLockScope _(m_Lock);

	m_ServerInterface = ServerIface;

	for (LoadedDll& Transport : m_Transports)
	{
		// Best-effort: a plugin that fails to initialize must not prevent the
		// remaining plugins from being initialized.
		try
		{
			Transport.Plugin->Initialize(ServerIface);
		}
		catch (const std::exception&)
		{
			// TODO: report the failure (plugin name is Transport.Name)
		}
	}
}

void
DllTransportPluginImpl::Shutdown()
{
	RwLock::ExclusiveLockScope _(m_Lock);

	for (LoadedDll& Transport : m_Transports)
	{
		try
		{
			Transport.Plugin->Shutdown();
		}
		catch (const std::exception&)
		{
			// TODO: report
		}
	}
}

const char*
DllTransportPluginImpl::GetDebugName()
{
	// The DLL host does not supply a debug name; callers receive a null pointer.
	constexpr const char* NoDebugName = nullptr;
	return NoDebugName;
}

bool
DllTransportPluginImpl::IsAvailable()
{
	// Reported as available unconditionally; hosted DLL plugins are not consulted.
	constexpr bool AlwaysAvailable = true;
	return AlwaysAvailable;
}

void
DllTransportPluginImpl::ConfigureDll(std::string_view Name, const char* OptionTag, const char* OptionValue)
{
	RwLock::ExclusiveLockScope _(m_Lock);

	for (auto& Transport : m_Transports)
	{
		if (Transport.Name == Name)
		{
			Transport.Plugin->Configure(OptionTag, OptionValue);
		}
	}
}

void
DllTransportPluginImpl::LoadDll(std::string_view Name)
{
	RwLock::ExclusiveLockScope _(m_Lock);

	ExtendableStringBuilder<128> DllPath;
	DllPath << Name << ".dll";
	HMODULE DllHandle = LoadLibraryA(DllPath.c_str());

	if (!DllHandle)
	{
		std::error_code Ec = MakeErrorCodeFromLastError();

		throw std::system_error(Ec, fmt::format("failed to load transport DLL from '{}'", DllPath));
	}

	TransportPlugin* CreateTransportPlugin();

	PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin");

	if (!CreatePlugin)
	{
		std::error_code Ec = MakeErrorCodeFromLastError();

		FreeLibrary(DllHandle);

		throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath));
	}

	LoadedDll NewDll;

	NewDll.Name			  = Name;
	NewDll.LoadedFromPath = DllPath.c_str();
	NewDll.Plugin		  = CreatePlugin();

	m_Transports.emplace_back(std::move(NewDll));
}

DllTransportPlugin*
CreateDllTransportPlugin()
{
	// Factory for the DLL transport host. Returns a newly heap-allocated instance.
	// NOTE(review): presumably the caller manages its lifetime through AddRef/Release — confirm.
	DllTransportPluginImpl* Instance = new DllTransportPluginImpl();
	return Instance;
}

}  // namespace zen

#endif
